{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "from pytorch_tabular.utils import print_metrics, load_covertype_dataset\n",
    "np.random.seed(42)\n",
    "random.seed(42)\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Loading CoverType Dataset\n",
    "\n",
    "Predicting forest cover type from cartographic variables only (no remotely sensed data). The actual forest cover type for a given observation (30 x 30 meter cell) was determined from US Forest Service (USFS) Region 2 Resource Information System (RIS) data. Independent variables were derived from data originally obtained from US Geological Survey (USGS) and USFS data. Data is in raw form (not scaled) and contains binary (0 or 1) columns of data for qualitative independent variables (wilderness areas and soil types).\n",
    "\n",
    "This study area includes four wilderness areas located in the Roosevelt National Forest of northern Colorado. These areas represent forests with minimal human-caused disturbances, so that existing forest cover types are more a result of ecological processes rather than forest management practices.\n",
    "\n",
    "It is from [UCI ML Repository](https://archive.ics.uci.edu/ml/datasets/covertype)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "data, cat_col_names, num_col_names, target_name = load_covertype_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Unlabelled Data: 580430 rows | Labelled Data: 582\n"
     ]
    }
   ],
   "source": [
    "# Splitting data into data for SSL and data for finetuning\n",
    "ssl, finetune = train_test_split(data, random_state=42, test_size=0.001)\n",
    "# Train and val splits\n",
    "ssl_train, ssl_val = train_test_split(ssl, random_state=42)\n",
    "finetune_train, finetune_test = train_test_split(finetune, random_state=42)\n",
    "finetune_train, finetune_val = train_test_split(finetune_train, random_state=42)\n",
    "print(f\"Unlabelled Data: {ssl.shape[0]} rows | Labelled Data: {finetune.shape[0]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Importing the Library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from pytorch_tabular import TabularModel\n",
    "from pytorch_tabular.models import CategoryEmbeddingModelConfig\n",
    "from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig, ExperimentConfig\n",
    "from pytorch_tabular.models.common.heads import LinearHeadConfig\n",
    "from pytorch_tabular.ssl_models.dae import DenoisingAutoEncoderConfig"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Self-Supervised Learning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An excerpt from the article by Yann LeCun and Ishan Mishra from Meta will serve as a good introduction here:\n",
    "> Supervised learning is a bottleneck for building more intelligent generalist models that can do multiple tasks and acquire new skills without massive amounts of labeled data. Practically speaking, it’s impossible to label everything in the world. There are also some tasks for which there’s simply not enough labeled data, such as training translation systems for low-resource languages.\n",
    "\n",
    "> As babies, we learn how the world works largely by observation. We form generalized predictive models about objects in the world by learning concepts such as object permanence and gravity. Later in life, we observe the world, act on it, observe again, and build hypotheses to explain how our actions change our environment by trial and error.\n",
    "\n",
    ">Common sense helps people learn new skills without requiring massive amounts of teaching for every single task. For example, if we show just a few drawings of cows to small children, they’ll eventually be able to recognize any cow they see. By contrast, AI systems trained with supervised learning require many examples of cow images and might still fail to classify cows in unusual situations, such as lying on a beach. How is it that humans can learn to drive a car in about 20 hours of practice with very little supervision, while fully autonomous driving still eludes our best AI systems trained with thousands of hours of data from human drivers? The short answer is that humans rely on their previously acquired background knowledge of how the world works.\n",
    "\n",
    "> How do we get machines to do the same?\n",
    "\n",
    ">We believe that self-supervised learning (SSL) is one of the most promising ways to build such background knowledge and approximate a form of common sense in AI systems.\n",
    "\n",
    "[Full Article](https://ai.facebook.com/blog/self-supervised-learning-the-dark-matter-of-intelligence/) is a very interesting read.\n",
    "\n",
    "SSL has been very successfully used in NLP (All the Large Language Models which create magic is learnt through SSL), and with some success in Computer Vision. But can we do that with Tabular data? The answer is yes.\n",
    "\n",
    "There are many research papers which talk about such models:\n",
    "\n",
    "1. [TabNet](https://arxiv.org/pdf/1908.07442v5.pdf) talks about how it can be used for SSL\n",
    "2. [VIME: Extending the Success of Self- and Semi-supervised Learning to Tabular Domain](https://proceedings.neurips.cc//paper/2020/file/7d97667a3e056acab9aaf653807b4a03-Paper.pdf) also proposes another SSL model for tabular\n",
    "3. And so does [SubTab: Subsetting Features of Tabular Data for Self-Supervised Representation Learning](https://arxiv.org/pdf/2110.04361v2.pdf)\n",
    "\n",
    "And before all these, there was [Denoising AutoEncoder](https://www.kaggle.com/code/faisalalsrheed/denoising-autoencoders-dae-for-tabular-data) which was used for winning solutions in many Tabular Playground competitions.\n",
    "\n",
    "PyTorch Tabular has provided the Denoising AutoEncoder implementation inspired from [https://github.com/ryancheunggit/tabular_dae](https://github.com/ryancheunggit/tabular_dae). Let's see how we can use that.\n",
    "\n",
    "A Denoising AutoEncoder has the below architecture ([Source](https://towardsdatascience.com/generating-images-with-autoencoders-77fd3a8dd368)):\n",
    "\n",
    "<img src=\"../imgs/auto_encoder.png\" alt=\"DAE\" width=\"500\"/>\n",
    "\n",
    "We corrupt the input on the left and we ask the model to learn to predict the orginal, denoised input. By this process, the network is forced to learn a compressed bottleneck (labelled `code`) which captures most of the characteristics of the input data, i.e. a robust representation of the input data.\n",
    "\n",
    "In the `DenoisingAutoencoder` implementation in PyTorchTabular, the noise is introduced in two ways:\n",
    "1. `swap` - In this strategy, noise is introduced by replacing a value in a feature with another value of the same feature, randomly sampled from the rest of the rows.\n",
    "\n",
    "2. `zero` - In here, noise is introduced by just replacing the value with zero.\n",
    "\n",
    "In addition to that, we also can set `noise_probabilities`, with which we can define the probability with which noise will be introduced to a feature. We can set this parameter as a dictionary of the form, `{featurename: noise_probability}`. Or we can also set a single probability for all the features easily by using `default_noise_probability`\n",
    "\n",
    "Now once we have this robust representation, we can use this representation in other downstream tasks like regression or classification.\n",
    "\n",
    "A typical SSL workflow would have a large dataset without labels, and a smaller dataset with labels for finetuning.\n",
    "\n",
    "1. We start with Pre-training using unlabelled data\n",
    "2. Then we use the pre-trained model (the `code`in the diagram) for a dowstream task like `regression` or `classification`  \n",
    "    i. We create a new model with the pretrained model as the backbone and a head for prediction  \n",
    "    ii. We train the new model (finetune) on small labelled data\n",
    "    \n",
    "This approach would typically work better than a purely supervised model using just the small labelled dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fully Supervised Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's first train a fully supervised model using just the ~5k rows of labelled data. This can be used as a baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_config = DataConfig(\n",
    "    target=[target_name],\n",
    "    continuous_cols=num_col_names,\n",
    "    categorical_cols=cat_col_names,\n",
    "    normalize_continuous_features=True,\n",
    ")\n",
    "\n",
    "trainer_config = TrainerConfig(\n",
    "    batch_size=2048,\n",
    "    max_epochs=1000,\n",
    "    early_stopping=\"valid_loss\", # Turning off Early Stopping\n",
    "    checkpoints=\"valid_loss\", # Save best checkpoint monitoring val_loss\n",
    "    load_best=True, # After training, load the best checkpoint\n",
    ")\n",
    "\n",
    "optimizer_config = OptimizerConfig()\n",
    "\n",
    "model_config = CategoryEmbeddingModelConfig(\n",
    "    task=\"classification\",\n",
    "    layers=\"2000-1000\",\n",
    "    activation=\"ReLU\",\n",
    "    dropout=0.1,\n",
    "    initialization=\"kaiming\",\n",
    "    head=\"LinearHead\",\n",
    "    head_config={\n",
    "        \"layers\": \"\",\n",
    "        \"activation\": \"ReLU\",\n",
    "    },\n",
    "    learning_rate = 1e-3\n",
    ")\n",
    "\n",
    "tabular_model = TabularModel(\n",
    "    data_config=data_config,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    "    verbose=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type                      </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _backbone        │ CategoryEmbeddingBackbone │  2.1 M │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _embedding_layer │ Embedding1dLayer          │    427 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ head             │ LinearHead                │  7.0 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ CrossEntropyLoss          │      0 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType                     \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ CategoryEmbeddingBackbone │  2.1 M │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer          │    427 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ head             │ LinearHead                │  7.0 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ CrossEntropyLoss          │      0 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 2.1 M                                                                                            \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 2.1 M                                                                                                \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 8                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 2.1 M                                                                                            \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 2.1 M                                                                                                \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 8                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6472af88fc214632b3c9936040689590",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connecto\n",
       "rs/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider \n",
       "increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
       "</pre>\n"
      ],
      "text/plain": [
       "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connecto\n",
       "rs/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider \n",
       "increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connecto\n",
       "rs/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider \n",
       "increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
       "</pre>\n"
      ],
      "text/plain": [
       "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connecto\n",
       "rs/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider \n",
       "increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.p\n",
       "y:293: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a\n",
       "lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
       "</pre>\n"
      ],
      "text/plain": [
       "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.p\n",
       "y:293: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a\n",
       "lower value for log_every_n_steps if you want to see logs for the training epoch.\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<pytorch_lightning.trainer.trainer.Trainer at 0x7f3eee9967d0>"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tabular_model.fit(\n",
    "    train=finetune_train, \n",
    "    validation=finetune_val\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e0d8fff8c8664e44b8a2030234d657a6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.6438356041908264     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     1.331926703453064     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.6438356041908264    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    1.331926703453064    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "result = tabular_model.evaluate(finetune_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "result = result[0]\n",
    "result['Type'] = \"Supervised\"\n",
    "results.append(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Denoising AutoEncoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Pre-training\n",
    "\n",
    "Now, we use ~575k unlabelled data to do self-supervised learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 2048\n",
    "steps_per_epoch = int(ssl_train.shape[0]/batch_size)\n",
    "epochs = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "Collapsed": "false",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "ssl_data_config = DataConfig(\n",
    "    target=None, #Setting target as None because we don't need the target for SSL\n",
    "    continuous_cols=num_col_names,\n",
    "    categorical_cols=cat_col_names,\n",
    "    normalize_continuous_features=True,\n",
    "    handle_missing_values=False, # For SSL tasks, missing values and unknwon categories will not be handled automatically\n",
    "    handle_unknown_categories=False, # Not Setting these configs to False will throw and error when initializing TabularModel \n",
    ")\n",
    "\n",
    "ssl_trainer_config = TrainerConfig(\n",
    "    batch_size=batch_size,\n",
    "    max_epochs=epochs,\n",
    "    early_stopping=\"valid_loss\", # Turning off Early Stopping\n",
    "    checkpoints=\"valid_loss\", # Save best checkpoint monitoring val_loss\n",
    "    load_best=True, # After training, load the best checkpoint\n",
    ")\n",
    "\n",
    "# Setting OneCycleLR schedule\n",
    "ssl_optimizer_config = OptimizerConfig(\n",
    "    lr_scheduler=\"OneCycleLR\",\n",
    "    lr_scheduler_params={\n",
    "        \"max_lr\":1e-2, \n",
    "        \"epochs\": epochs, \n",
    "        \"steps_per_epoch\":steps_per_epoch\n",
    "    }\n",
    ")\n",
    "\n",
    "# Setting the encoder config\n",
    "encoder_config = CategoryEmbeddingModelConfig(\n",
    "    task=\"backbone\",\n",
    "    layers=\"4000-2000-1000-512\",  # Number of nodes in each layer\n",
    "    activation=\"ReLU\",  # Activation between each layers\n",
    "    head=None, # If not set to None, it will throw a warning\n",
    ")\n",
    "\n",
    "# Setting the decoder config.\n",
    "# NOTE: the last dimension in encoder layers should be first dimension in decoder layers\n",
    "# i.e. last encoder layer dim = 512, first decoder layer dim = 512\n",
    "decoder_config = CategoryEmbeddingModelConfig(\n",
    "    task=\"backbone\",\n",
    "    layers=\"512-2048-4096\",  # Number of nodes in each layer\n",
    "    activation=\"ReLU\",  # Activation between each layers\n",
    "    head=None, # If not set to None, it will throw a warning\n",
    ")\n",
    "\n",
    "# DAE Config. No need to set task because it is hardcoded to SSL\n",
    "# Can't set any loss or metrics as well because for the SSL task\n",
    "# (especially for DAE), the loss and metrics are fixed.\n",
    "ssl_model_config = DenoisingAutoEncoderConfig(\n",
    "    noise_strategy=\"swap\", # Can be 'zero' as well. Defines if noise is swapping features or making features zero\n",
    "    default_noise_probability = 0.7, # Probability of corruption by noise\n",
    "    include_input_features_inference=True, # Setting this to True will include the input features in the concatenated layers in the decoder\n",
    "    encoder_config=encoder_config, \n",
    "    decoder_config=decoder_config, \n",
    "    learning_rate=1e-3)\n",
    "\n",
    "ssl_tabular_model = TabularModel(\n",
    "    data_config=ssl_data_config,\n",
    "    model_config=ssl_model_config,\n",
    "    optimizer_config=ssl_optimizer_config,\n",
    "    trainer_config=ssl_trainer_config,\n",
    "    verbose=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name                </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type                           </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ encoder             │ CategoryEmbeddingBackbone      │ 10.7 M │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ decoder             │ CategoryEmbeddingBackbone      │  9.7 M │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ _featurizer         │ DenoisingAutoEncoderFeaturizer │ 10.7 M │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ _embedding          │ MixedEmbedding1dLayer          │    896 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 4 </span>│ reconstruction      │ MultiTaskHead                  │  139 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 5 </span>│ mask_reconstruction │ Linear                         │  139 K │\n",
       "└───┴─────────────────────┴────────────────────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName               \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType                          \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ encoder             │ CategoryEmbeddingBackbone      │ 10.7 M │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ decoder             │ CategoryEmbeddingBackbone      │  9.7 M │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ _featurizer         │ DenoisingAutoEncoderFeaturizer │ 10.7 M │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ _embedding          │ MixedEmbedding1dLayer          │    896 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m4\u001b[0m\u001b[2m \u001b[0m│ reconstruction      │ MultiTaskHead                  │  139 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m5\u001b[0m\u001b[2m \u001b[0m│ mask_reconstruction │ Linear                         │  139 K │\n",
       "└───┴─────────────────────┴────────────────────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 20.6 M                                                                                           \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 20.6 M                                                                                               \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 82                                                                         \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 20.6 M                                                                                           \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 20.6 M                                                                                               \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 82                                                                         \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cba7c0e9a3ea40599e0cafb02776b3e9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connecto\n",
       "rs/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider \n",
       "increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
       "</pre>\n"
      ],
      "text/plain": [
       "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connecto\n",
       "rs/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider \n",
       "increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connecto\n",
       "rs/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider \n",
       "increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
       "</pre>\n"
      ],
      "text/plain": [
       "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connecto\n",
       "rs/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider \n",
       "increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<pytorch_lightning.trainer.trainer.Trainer at 0x7f3eeead9f90>"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pretraining\n",
    "ssl_tabular_model.pretrain(train=ssl_train, validation=ssl_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-Tuning\n",
    "\n",
    "Now we create a finetune model using the pretrained weights, and funetune the model for the classification task using the ~5k labelled data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "ft_trainer_config = TrainerConfig(\n",
    "    batch_size=2048,\n",
    "    max_epochs=1000,\n",
    "    early_stopping=\"valid_loss\", # Turning off Early Stopping\n",
    "    checkpoints=\"valid_loss\", # Save best checkpoint monitoring val_loss\n",
    "    load_best=True, # After training, load the best checkpoint\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_optimizer import QHAdam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "ft_optimizer_config = OptimizerConfig(\n",
    "    lr_scheduler=\"OneCycleLR\",\n",
    "    lr_scheduler_params={\n",
    "        \"max_lr\":1e-3, \n",
    "        \"epochs\": epochs, \n",
    "        \"steps_per_epoch\":steps_per_epoch\n",
    "    }\n",
    ")\n",
    "\n",
    "finetune_model = ssl_tabular_model.create_finetune_model(\n",
    "    task=\"classification\",\n",
    "    train=finetune_train, \n",
    "    validation=finetune_val, \n",
    "    target=[target_name], #Provide the column name of the target as a list\n",
    "    head=\"LinearHead\",\n",
    "    head_config={\n",
    "        \"layers\": \"256-64\",\n",
    "        \"activation\": \"ReLU\",\n",
    "    },\n",
    "    trainer_config=ft_trainer_config,# Overriding previous trainer config\n",
    "    optimizer_config=ft_optimizer_config,\n",
    "    optimizer=QHAdam, # Using a custom optimizer\n",
    "    optimizer_params={\"nus\": (0.7, 1.0), \"betas\": (0.95, 0.998)}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check if the new model has the pretrained weights\n",
    "import torch\n",
    "assert torch.equal(ssl_tabular_model.model.encoder.linear_layers[0].weight, finetune_model.model._backbone.encoder.linear_layers[0].weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name        </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type                      </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ custom_loss │ CrossEntropyLoss          │      0 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _backbone   │ DenoisingAutoEncoderModel │ 20.6 M │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ _head       │ LinearHead                │  156 K │\n",
       "└───┴─────────────┴───────────────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName       \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType                     \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ custom_loss │ CrossEntropyLoss          │      0 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _backbone   │ DenoisingAutoEncoderModel │ 20.6 M │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ _head       │ LinearHead                │  156 K │\n",
       "└───┴─────────────┴───────────────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 156 K                                                                                            \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 20.6 M                                                                                       \n",
       "<span style=\"font-weight: bold\">Total params</span>: 20.8 M                                                                                               \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 83                                                                         \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 156 K                                                                                            \n",
       "\u001b[1mNon-trainable params\u001b[0m: 20.6 M                                                                                       \n",
       "\u001b[1mTotal params\u001b[0m: 20.8 M                                                                                               \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 83                                                                         \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "37fb066c98bf4f588d6daf6f6aa0d7e2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connecto\n",
       "rs/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider \n",
       "increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
       "</pre>\n"
      ],
      "text/plain": [
       "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connecto\n",
       "rs/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider \n",
       "increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connecto\n",
       "rs/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider \n",
       "increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
       "</pre>\n"
      ],
      "text/plain": [
       "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connecto\n",
       "rs/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider \n",
       "increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.p\n",
       "y:293: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a\n",
       "lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
       "</pre>\n"
      ],
      "text/plain": [
       "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.p\n",
       "y:293: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a\n",
       "lower value for log_every_n_steps if you want to see logs for the training epoch.\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<pytorch_lightning.trainer.trainer.Trainer at 0x7f3eeea3b150>"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "finetune_model.finetune(\n",
    "    freeze_backbone=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "Collapsed": "false",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "512503361cd641c196803ade9f493bd7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.6095890402793884     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.9251274466514587     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.6095890402793884    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9251274466514587    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "result = finetune_model.evaluate(finetune_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "result = result[0]\n",
    "result['Type'] = \"Self-Supervised\"\n",
    "results.append(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the Self-Supervised approach has better lower loss. This is not a definite phenomenon and it depends on how well we were able to learn the representation during Pre-training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test_loss</th>\n",
       "      <th>test_accuracy</th>\n",
       "      <th>Type</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.331927</td>\n",
       "      <td>0.643836</td>\n",
       "      <td>Supervised</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.925127</td>\n",
       "      <td>0.609589</td>\n",
       "      <td>Self-Supervised</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   test_loss  test_accuracy             Type\n",
       "0   1.331927       0.643836       Supervised\n",
       "1   0.925127       0.609589  Self-Supervised"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame(results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "ad8d5d2789703c7b1c2f7bfaada1cbd3aa0ac53e2e4e1cae5da195f5520da229"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
